
Basic all_gather implementation #1663

Merged

polvalente merged 17 commits into elixir-nx:main from Chapaman:all_gather-implementation on Feb 13, 2026

Conversation

@Chapaman (Contributor)

Implements Nx.Defn.Kernel.all_gather/2 to gather sharded tensor data across mesh partitions during distributed execution.
Changes

Nx
- Add all_gather/2 in defn/kernel.ex and defn/expr.ex with sharding semantics
- Add evaluator support for all_gather in defn/evaluator.ex

EXLA
- Lower all_gather to stablehlo.all_gather in defn.ex and mlir/value.ex

Tests
- EXLA.Defn.ShardingTest: "generates correct MLIR with all_gather" checks MLIR generation and shard_jit output across a 2×2 mesh along axes 0 and 1
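
A hedged usage sketch of the new primitive inside defn; the module name is invented, and the option names (all_gather_dim, replica_groups) follow the review discussion below, so they may differ from the merged API:

    defmodule MyShardedOps do
      import Nx.Defn

      # Hypothetical: gather the shards of `t` along dimension 0 across a
      # four-device mesh. Option names are assumptions, not the merged API.
      defn gather_rows(t) do
        all_gather(t, all_gather_dim: 0, replica_groups: [[0, 1, 2, 3]])
      end
    end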

@Chapaman changed the title from "Basic gall_gather implementation" to "Basic all_gather implementation" on Jan 30, 2026
Comment on lines 1481 to 1489
    Value.all_gather(
      [tensor],
      expr_to_typespec(ans),
      all_gather_dim,
      replica_groups,
      use_global_device_ids,
      Keyword.take(opts, [:channel_id])
    )
    |> hd()
@polvalente (Contributor) Jan 30, 2026

Let's hard-match for now instead of hd (i.e., [result] = Value...).
And then add a comment that we might want to surface all_gather as an operation that takes a container of operands instead of a single one.
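
A minimal sketch of the suggested change, reusing the names from the snippet above:

    # Hard-match the single result instead of piping into hd/1.
    # NOTE: we may want to surface all_gather as an operation that takes
    # a container of operands instead of a single one.
    [result] =
      Value.all_gather(
        [tensor],
        expr_to_typespec(ans),
        all_gather_dim,
        replica_groups,
        use_global_device_ids,
        Keyword.take(opts, [:channel_id])
      )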


    attributes =
      if opts[:channel_id] do
        attributes ++ [channel_id: attr_i64(opts[:channel_id])]
Contributor

Let's use Keyword.put instead of ++.
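
A sketch of the suggested rewrite, assuming the same attributes keyword list as in the diff:

    attributes =
      if opts[:channel_id] do
        # Keyword.put replaces any existing :channel_id entry rather
        # than appending a duplicate key.
        Keyword.put(attributes, :channel_id, attr_i64(opts[:channel_id]))
      else
        attributes
      end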

    if opts[:channel_id] do
      attributes ++ [channel_id: attr_i64(opts[:channel_id])]
    else
      attributes end
Contributor

formatting

end
end

    def all_gather(
          [%Value{function: func} | _] = operands,
          typespec,
          all_gather_dim,
          replica_groups,
          use_global_device_ids,
          opts \\ []
        ) do
Contributor

How about channel_id being a required argument, and we just pass the value directly?
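
A hypothetical sketch of that signature change (illustrative only; the merged code may differ):

    # channel_id as a required positional argument instead of an option.
    def all_gather(
          [%Value{function: func} | _] = operands,
          typespec,
          all_gather_dim,
          replica_groups,
          use_global_device_ids,
          channel_id
        ) do
      # ...lower to stablehlo.all_gather, attaching channel_id directly...
    end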

Comment on lines 481 to 484
    if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
      raise ArgumentError,
            "all_gather/3 is not supported by backend #{inspect(mod)}."
    end
Contributor

If we remove this, do we have a test verifying this raise? Also, I believe this is already checked elsewhere.

Contributor

If it's not, it seems to me that this check should be more general.


* `tensor` - The input tensor to gather
* `all_gather_dim` - The dimension along which to gather
* `replica_groups` - 2D list defining how replicas are grouped (required)
Contributor

I'm not sure if this is the terminology we want to surface here. For now, let's make the function all_gather(tensor, opts) and defer the documentation of opts to the specific backend or compiler.

And in EXLA we should add a new section to the moduledoc of EXLA describing Sharding.
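
A minimal sketch of that shape (the delegation target is an assumption based on the files listed in the description):

    @doc """
    Gathers sharded tensor data across mesh partitions.

    Options are compiler-specific; see the compiler or backend
    documentation (for EXLA, the Sharding section of its moduledoc).
    """
    def all_gather(tensor, opts \\ []) do
      # Hypothetical delegation into the expression builder added in defn/expr.ex.
      Nx.Defn.Expr.all_gather(tensor, opts)
    end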

@polvalente (Contributor) left a comment

This is looking great! I think we need more tests in both Nx and EXLA.

@polvalente merged commit 89f65dd into elixir-nx:main on Feb 13, 2026. 16 of 18 checks passed.
